import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from cat_rfs.code.utils.graph import convert_to_percentages
from cat_rfs.code.utils.tools import keep_perils
from cat_rfs.code.utils.tools import save_stat

# Named variants of Figure 3, selected through the ``setting`` argument of
# ``print_fig3``.  Every variant carries the same four keys:
#   yield_var -- column of the input frame that is renamed to 'yield'
#   wsst      -- if truthy, the '*_wsst' versions of attachment_prob,
#                expected_loss and exhaustion_prob replace the base columns
#   fig_file  -- file-name stem used for the saved figures
#   perils    -- peril label to keep, or None for no peril filter
SETTINGS = {
    'main': {
        'yield_var': 'sheet_yield_usd',
        'wsst': False,
        'fig_file': 'fig3',
        'perils': None,
    },
    'trace': {
        'yield_var': 'act_yield_usd',
        'wsst': False,
        'fig_file': 'fig3_trace',
        'perils': None,
    },
    'wsst': {
        'yield_var': 'sheet_yield_usd',
        'wsst': True,
        'fig_file': 'fig3_wsst',
        'perils': None,
    },
    'hurricane': {
        'yield_var': 'sheet_yield_usd',
        'wsst': True,
        'fig_file': 'fig3_hurricane',
        'perils': 'Cyclone',
    },
    'earthquake': {
        'yield_var': 'sheet_yield_usd',
        'wsst': True,
        'fig_file': 'fig3_earthquake',
        'perils': 'Earthquake',
    },
}


# Windows shaded on both figures as "Qualifying disaster(s)".
_DISASTER_WINDOWS = (
    (pd.Timestamp('2005-07-01'), pd.Timestamp('2006-06-30')),  # Katrina
    (pd.Timestamp('2008-07-01'), pd.Timestamp('2009-06-30')),  # Ike, Lehman
    (pd.Timestamp('2010-07-01'), pd.Timestamp('2011-06-30')),  # Tohoku
    # NOTE(review): the original comment here repeated "Patricia"; this window
    # presumably refers to Sandy (Oct 2012) -- confirm.
    (pd.Timestamp('2012-07-01'), pd.Timestamp('2013-06-30')),  # Sandy (?)
    (pd.Timestamp('2015-07-01'), pd.Timestamp('2016-06-30')),  # Patricia
    (pd.Timestamp('2017-07-01'), pd.Timestamp('2018-06-30')),  # HIM
)


def _shade_disasters(ax):
    """Shade every disaster window on *ax*; only the first span gets a legend label."""
    first, *rest = _DISASTER_WINDOWS
    ax.axvspan(first[0], first[1], color='lightgray', label='Qualifying disaster(s)')
    for start, end in rest:
        ax.axvspan(start, end, color='lightgray')


def _save_or_show(output, path):
    """Save the current figure to *path* and close it for 'paper' output, else show it."""
    if output == 'paper':
        plt.savefig(path)
        plt.close()
    else:
        plt.show()


def print_fig3(df, output='console', setting='main'):
    """
    Print Figure 3: Time series evolution of price of natural disaster risk.

    Two figures are produced:

    * the observed premium (size-weighted excess return divided by
      ``var_m_sheet``) against the premium predicted from leverage
      (``alternative_capital / aum`` read from ``cat_rfs/data/ts.csv``);
    * the size-weighted excess return (left axis) and ``var_m_sheet``
      (right axis) shown separately.

    Parameters
    ----------
    df : pandas.DataFrame
        Bond-level panel. The columns read here are ``date`` (datetime),
        the yield column named by the chosen setting, ``expected_loss``,
        ``rf``, ``size`` and ``var_m_sheet``; ``perils`` when the setting
        filters on a peril; and the ``*_wsst`` columns when the setting has
        ``wsst`` enabled. The caller's frame is not modified.
    output : str
        ``'paper'`` saves each figure as EPS under ``cat_rfs/output/figures/``
        and closes it; any other value displays the figures with
        ``plt.show()``. ``'presentation'`` additionally uses short legend
        labels without formulas. The value is also forwarded to ``save_stat``.
    setting : str
        Key of ``SETTINGS`` selecting the variant of the figure.

    Returns
    -------
    None

    Raises
    ------
    KeyError
        If ``setting`` is not a key of ``SETTINGS``.
    """
    cfg = SETTINGS[setting]
    fname = cfg['fig_file']
    yield_ = cfg['yield_var']
    peril = cfg['perils']

    # Keep second-quarter observations only. The original note said
    # "Jun only" -- presumably dates are quarter-ends so this selects the
    # June observation of each year; confirm against the data.
    df = df[df['date'].dt.quarter == 2].copy()

    if cfg['wsst']:
        # Replace the base risk metrics by their '_wsst' counterparts.
        base_cols = ['attachment_prob', 'expected_loss', 'exhaustion_prob']
        df = df.drop(columns=base_cols).rename(
            columns={col + '_wsst': col for col in base_cols})

    if yield_ == 'act_yield_usd':
        # Restrict this yield series to 2004-12-31 .. 2018-12-31 (inclusive).
        in_window = ((df['date'] >= pd.Timestamp('2004-12-31'))
                     & (df['date'] <= pd.Timestamp('2018-12-31')))
        df = df[in_window]

    df = df.rename(columns={yield_: 'yield'})
    df = df.dropna(subset=['yield'])

    if peril:
        # keep_perils() rewrites df (presumably normalising the 'perils'
        # column -- see its definition); the raw labels are stashed first and
        # restored once the filter has been applied.
        df['perilsc'] = df['perils'].copy()
        df = keep_perils(df)
        # This keeps bonds exposed to multiple regions as long as the peril type is the same
        df = df[df['perils'] == peril].copy()
        df['perils'] = df['perilsc'].copy()
        df = df.drop(columns=['perilsc'])

    # Excess return: yield net of expected loss and the risk-free rate
    # (both divided by 100 -- presumably quoted in percent; confirm).
    df['er'] = df['yield'] - df['expected_loss'] / 100 - df['rf'] / 100

    # Size-weighted average excess return per date; 'var_m_sheet' takes the
    # first value within each date.
    df['er_m'] = df['er'] * df['size']
    df = df.groupby('date', as_index=False).agg(
        {'er_m': 'sum', 'size': 'sum', 'var_m_sheet': 'first'})
    df['er_m'] = df['er_m'] / df['size']

    ts = pd.read_csv('cat_rfs/data/ts.csv', parse_dates=['date'])
    ts['lev'] = ts['alternative_capital'] / ts['aum']

    df = df.merge(ts, on='date')
    df['ratio'] = df['er_m'] / df['var_m_sheet']
    df = df[df['lev'].notnull()]

    # Scale leverage so that the predicted series has the same mean as the
    # observed ratio.
    rho = df['ratio'].mean() / df['lev'].mean()
    df['ratio_pred'] = df['lev'] * rho

    if setting == 'main':
        FRAC_OWNED = 0.5
        save_stat(FRAC_OWNED * 100, 'ownership_assumption', formatting='%.0f', output=output)
        save_stat(rho, 'crra_raw', formatting='%.1f', output=output)
        save_stat(rho / FRAC_OWNED, 'crra', formatting='%.1f', output=output)
        save_stat(df[['ratio', 'ratio_pred']].corr().iloc[0, 1], 'corr_ts', formatting='%.2f', output=output)

    # Plot
    sns.set_palette('Greys_r')

    df = df.set_index('date')
    # Daily linear interpolation, computed once for both figures and limited
    # to the plotted columns so unrelated columns from ts.csv are not touched.
    daily = (df[['ratio', 'ratio_pred', 'er_m', 'var_m_sheet']]
             .resample('D')
             .interpolate(method='linear'))
    x_min, x_max = daily.index.min(), daily.index.max()

    # --- Figure (a): observed vs. predicted premium -------------------------
    if output == 'presentation':
        labels = {'ratio': 'Observed premium', 'ratio_pred': 'Predicted premium'}
    else:
        labels = {'ratio': 'Observed premium $\\left(\\frac{E_t\\left(R^e_{cat}\\right)}{\\mathrm{Var}_t\\left(R^e_{cat}\\right)}\\right)$',
                  'ratio_pred': 'Predicted premium $\\left(\\hat{\\rho}\\frac{Size_t}{AUM_t}\\right)$'}
    premium = daily[['ratio', 'ratio_pred']].rename(columns=labels)
    ax = premium.plot(style=['k-', 'k--'], figsize=(6.5, 3.5))

    _shade_disasters(ax)

    ax.legend(loc='upper right')
    ax.set_ylim(bottom=0)
    ax.set_xlim(x_min, x_max)
    ax.set_ylabel('Premium on natural disaster risk')
    ax.set_xlabel('')

    _save_or_show(output, f'cat_rfs/output/figures/{fname}.eps')

    # --- Figure (b): expected return and variance separately ----------------
    ax = daily[['er_m']].plot(style=['k-'], figsize=(6.5, 3.5))

    _shade_disasters(ax)
    ax = convert_to_percentages(ax)
    # NOTE(review): this legend is replaced by the explicit one below. It is
    # kept because the secondary-axis plot may consult an existing legend
    # when deciding where to draw its own -- confirm before removing.
    ax.legend(loc='upper right')
    ax.set_ylim(bottom=0)
    ax.set_xlim(x_min, x_max)

    ax2 = daily[['var_m_sheet']].plot(style=['k--'], secondary_y=True, ax=ax, figsize=(6.5, 3.5))

    # Combined legend: left-axis line, right-axis line, first shaded span.
    ax.legend([ax.lines[0], ax2.lines[0], ax.patches[0]],
              ['Expected return $\\left(E_t\\left(R^e_{cat}\\right)\\right)$ (left)',
               'Variance $\\left(\\mathrm{Var}_t\\left(R^e_{cat}\\right)\\right)$ (right)',
               'Qualifying disaster(s)'], loc='lower center')

    ax.set_xlabel('')

    _save_or_show(output, f'cat_rfs/output/figures/{fname}_b.eps')
